import torch
from llava_v15.conversation import conv_llava_v1
from llava_v15.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava_v15.mm_utils import tokenizer_image_token
import re


def prepare_text_prompt(user_prompt):
    """Render ``user_prompt`` as a single-turn ``conv_llava_v1`` prompt string.

    The first-role message is ``DEFAULT_IMAGE_TOKEN``, a newline, then
    ``user_prompt``. A second-role message of ``None`` is appended before
    rendering (presumably to leave the reply slot open for generation --
    confirm against ``conv_llava_v1``).

    Args:
        user_prompt: Text placed after the image placeholder; must be a ``str``.

    Returns:
        Whatever ``get_prompt()`` of the conversation copy returns.
    """
    # Image placeholder goes first, on its own line, ahead of the user text.
    image_query = ''.join((DEFAULT_IMAGE_TOKEN, '\n', user_prompt))

    # Work on a copy so the shared conversation template is never mutated.
    conversation = conv_llava_v1.copy()
    for role_index, message in ((0, image_query), (1, None)):
        conversation.append_message(conversation.roles[role_index], message)

    return conversation.get_prompt()

# Supports batched input: `Prompt` takes a list of text prompts and tokenizes each one.
class Prompt:
    """Tokenize a batch of text prompts that embed an adversarial suffix.

    Every prompt must contain the three markers ``<image>``,
    ``<adv_split_start>`` and ``</adv_split_end>``, which cut it into four
    segments::

        seg0 <image> seg1 <adv_split_start> seg2 </adv_split_end> seg3

    ``seg2`` is treated as the adversarial suffix. The two ``adv_split``
    markers are stripped from the processed prompt; ``<image>`` is kept so
    ``tokenizer_image_token`` can substitute the image token index.

    Padding and label preparation are intentionally not done here; per the
    original notes they have to wait until targets have been appended.
    """

    # The three markers that split a prompt into its four segments.
    _SEGMENT_SPLIT_RE = re.compile(r'<image>|<adv_split_start>|</adv_split_end>')

    def __init__(self, model, tokenizer, text_prompts=None, device='cuda:0',
                 num_image_tokens=576):
        """Store the collaborators and tokenize ``text_prompts`` immediately.

        Args:
            model: Stored on the instance; not used by this class's own code.
            tokenizer: Callable tokenizer; invoked as
                ``tokenizer(text, return_tensors="pt", add_special_tokens=False)``
                and also handed to ``tokenizer_image_token``.
            text_prompts: Iterable of prompt strings, or ``None`` for an
                empty instance.
            device: Device every produced tensor is moved to.
            num_image_tokens: Number of positions added for the image when
                computing ``offset``. Defaults to 576, the value that was
                previously hard-coded (presumably the visual-token count of
                the LLaVA-1.5 vision tower -- confirm for other models).

        Raises:
            ValueError: If a prompt lacks one of the required markers.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.num_image_tokens = num_image_tokens

        self.text_prompts = text_prompts
        # Per-prompt results, kept as parallel lists (one entry per prompt).
        self.text_prompts_processed = []  # prompt text with adv markers removed
        self.context_length = []          # input_ids.shape[1] for each prompt
        self.input_ids = []               # token-id tensor (batch dim of 1) per prompt

        # seg0 + seg1 token counts plus num_image_tokens, per prompt.
        self.offset = []
        # NOTE: these two are overwritten on every prompt, so after a batch
        # they describe the LAST prompt only.
        self.adv_len = 0
        self.adv_suffix_tokens = None

        self.do_tokenization(self.text_prompts)

    def do_tokenization(self, text_prompts):
        """Tokenize ``text_prompts`` and append the results to the instance.

        For each prompt this records the offset, the adversarial-suffix length
        and tokens, the marker-free prompt text, the full ``input_ids`` and
        their length. Results are appended, so calling this again with more
        prompts extends the existing lists.

        Args:
            text_prompts: Iterable of prompt strings, or ``None`` to reset
                ``input_ids`` and ``context_length`` to empty lists.

        Raises:
            ValueError: If a prompt does not split into at least four
                segments, i.e. one of the three markers is missing.
        """
        if text_prompts is None:
            self.input_ids = []
            self.context_length = []
            return

        for item in text_prompts:
            prompt_segs = self._SEGMENT_SPLIT_RE.split(item)
            # Validate before touching any state so a malformed prompt cannot
            # leave the parallel lists with mismatched lengths.
            if len(prompt_segs) < 4:
                raise ValueError(
                    "prompt must contain '<image>', '<adv_split_start>' and "
                    "'</adv_split_end>' markers; got %d segment(s): %r"
                    % (len(prompt_segs), item)
                )

            # Only the first three segments are measured; the trailing
            # segment is needed as text only.
            seg_tokens = [
                self.tokenizer(
                    seg, return_tensors="pt", add_special_tokens=False
                ).to(self.device).input_ids
                for seg in prompt_segs[:3]
            ]

            item_offset = (
                seg_tokens[0].shape[1]
                + seg_tokens[1].shape[1]
                + self.num_image_tokens
            )
            self.offset.append(item_offset)

            self.adv_len = seg_tokens[2].shape[1]
            # squeeze(0) drops only the batch dim, so a one-token suffix stays
            # a 1-D tensor instead of collapsing to a 0-dim scalar.
            self.adv_suffix_tokens = seg_tokens[2].squeeze(0)

            # Rebuild the prompt without the adv markers, keeping '<image>'.
            item_ = prompt_segs[0] + '<image>' + prompt_segs[1] + prompt_segs[2] + prompt_segs[3]
            self.text_prompts_processed.append(item_)

            # Honor the configured device rather than the default CUDA device.
            input_ids = tokenizer_image_token(
                item_, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
            ).unsqueeze(0).to(self.device)

            self.input_ids.append(input_ids)
            self.context_length.append(input_ids.shape[1])